import difflib
import json
import os.path as osp

from datasets import Dataset

from opencompass.openicl.icl_evaluator.code_evaluator import CodeEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# Language suffixes accepted for each dataset ``tag``; a split is stored as
# ``{tag}-{language}.jsonl``. MBPP covers the same set as HumanEval minus
# 'dart'.
_HUMANEVAL_LANGUAGE_ = (
    'adb clj cpp cs d dart elixir go hs java jl js lua ml php pl py r rb '
    'rkt rs scala sh swift ts').split()
_MBPP_LANGUAGE_ = (
    'adb clj cpp cs d elixir go hs java jl js lua ml php pl py r rb rkt rs '
    'scala sh swift ts').split()


@LOAD_DATASET.register_module()
class MultiplEDataset(BaseDataset):
    """Loads MultiPL-E style JSONL splits selected by ``tag`` and ``language``.
    """

    @staticmethod
    def load(path: str,
             language: str,
             tag: str = 'humaneval',
             local_mode: bool = False):
        """Load one split from ``{path}/{tag}-{language}.jsonl``.

        Each non-blank line of the file is parsed as one JSON record and
        becomes one row of the returned dataset.

        Args:
            path (str): Directory containing the JSONL files; resolved
                through ``get_data_path``.
            language (str): Language suffix of the split (e.g. ``'py'``).
                Must be supported for the chosen ``tag``.
            tag (str): Source benchmark, ``'humaneval'`` or ``'mbpp'``.
                Defaults to ``'humaneval'``.
            local_mode (bool): Whether to load the dataset in local mode;
                forwarded to ``get_data_path``.

        Returns:
            Dataset: A HuggingFace ``datasets.Dataset`` built from the
            parsed records.

        Raises:
            AssertionError: If ``tag`` or ``language`` is unsupported
                (checks are ``assert`` statements, so they are skipped
                under ``python -O``).
        """
        path = get_data_path(path, local_mode=local_mode)
        assert tag in ['humaneval',
                       'mbpp'], 'tag must be in ["humaneval", "mbpp"]'
        if tag == 'humaneval':
            assert language in _HUMANEVAL_LANGUAGE_, (
                f'language must be in {_HUMANEVAL_LANGUAGE_}')
        else:
            assert language in _MBPP_LANGUAGE_, (
                f'language must be in {_MBPP_LANGUAGE_}')
        file_path = osp.join(path, f'{tag}-{language}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. an extra newline at the end
                    # of the file) instead of failing on json.loads('').
                    continue
                dataset.append(json.loads(line))
        return Dataset.from_list(dataset)


class MultiplEEvaluator(CodeEvaluator):
    """``CodeEvaluator`` variant with MultiPL-E specific post-processing.

    ``_process_completions`` turns a generated answer into a completion in
    three steps: extract the code, drop any echoed copy of the prompt, then
    truncate at the first of the test case's stop tokens.
    """

    def _stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produces the prefix of decoded_string that ends at the first
        occurrence of a stop_token.

        WARNING: the decoded_string *must not* include the prompt,
        which may have stop tokens itself.

        Args:
            decoded_string: A string generated by the model.
            stop_tokens: A list of strings, where each string is a stop token.
                Empty strings are ignored; otherwise they would match at
                index 0 and erase the whole completion.
        Returns:
            The decoded_string, truncated at the first occurrence of a stop
            token, or unchanged if no stop token occurs.
        """
        hits = (decoded_string.find(token) for token in stop_tokens if token)
        cut = min((idx for idx in hits if idx != -1),
                  default=len(decoded_string))
        return decoded_string[:cut]

    def _remove_prefix(self,
                       prompt: str,
                       completion: str,
                       threshold: float = 0.95) -> str:
        """Strip an echoed copy of the prompt from a chat-style completion.

        The last prompt line with visible (non-whitespace) content is used
        as an anchor. The first completion line whose
        ``difflib.SequenceMatcher`` ratio against the anchor is at least
        ``threshold`` is treated as the end of the echoed prompt. Everything
        up to and including that line is removed. This converts
        chatbot-style responses into completion-mode output.

        Args:
            prompt (str): The prompt text.
            completion (str): The completion text.
            threshold (float): Minimum line similarity in ``[0, 1]`` for a
                completion line to count as the anchor.

        Returns:
            str: The lines after the matched anchor, joined with ``'\\n'``.
            If the prompt has no non-blank line or no completion line
            matches, ``completion`` is returned unchanged.
        """
        # Skip trailing blank/whitespace-only prompt lines: with a blank
        # anchor, SequenceMatcher('', '').ratio() == 1.0, so the first empty
        # completion line would "match" and truncate the real answer.
        anchor = next(
            (line for line in reversed(prompt.splitlines()) if line.strip()),
            None)
        if anchor is None:
            return completion

        completion_lines = completion.splitlines()
        for i, completion_line in enumerate(completion_lines):
            matcher = difflib.SequenceMatcher(None, anchor, completion_line)
            # real_quick_ratio() and quick_ratio() are cheap upper bounds on
            # ratio(), so most non-matching lines are rejected early without
            # changing which line is selected.
            if (matcher.real_quick_ratio() >= threshold
                    and matcher.quick_ratio() >= threshold
                    and matcher.ratio() >= threshold):
                return '\n'.join(completion_lines[i + 1:])
        return completion

    def _process_completions(self, test_case, completion):
        """Process completions with a test case.

        Applies, in order: ``self._extract_code`` (not defined in this
        class; presumably inherited from ``CodeEvaluator``), prompt-echo
        removal via ``_remove_prefix``, and truncation via
        ``_stop_at_stop_token``.

        Args:
            test_case (dict): A test case; its ``'prompt'`` and
                ``'stop_tokens'`` entries are used.
            completion (str): The generated code completion.
        Returns:
            str: Processed code completion.
        """
        post_comp = self._extract_code(completion)
        post_comp = self._remove_prefix(test_case['prompt'], post_comp)
        post_comp = self._stop_at_stop_token(post_comp,
                                             test_case['stop_tokens'])
        return post_comp